[ pytorch ] 基本使用丨2. 训练好的模型参数的保存以及调用丨 您所在的位置:网站首页 python self参数如何调用 [ pytorch ] 基本使用丨2. 训练好的模型参数的保存以及调用丨

[ pytorch ] 基本使用丨2. 训练好的模型参数的保存以及调用丨

2024-06-14 02:43| 来源: 网络整理| 查看: 265

保存与调用 def modelfunc(nn.Module): # 之前定义好的模型 def __init__(self, class_num=3): super(modelfunc, self).__init__() ... def forward(self,x): ... return x # 由于pytorch没有像keras那样有保存模型结构的API,因此,每次load之前必须找到模型的结构。 model_object = modelfunc(class_num=3) # 导入模型结构 # 保存和加载整个模型 torch.save(model_object, 'model.pth') model = torch.load('model.pth') # 仅保存和加载模型参数 torch.save(model_object.state_dict(), 'params.pth') model_object.load_state_dict(torch.load('params.pth')) torch.load 的输出: # 保存和加载整个模型 torch.save(model_object, 'model.pth') model = torch.load('model.pth') print(model) >>>【结果】 modelfunc( (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) (layer1): Sequential( (0): Bottleneck( (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) (downsample): Sequential( (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False) (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): Bottleneck( (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) ) (2): Bottleneck( (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) ) ) (layer2): Sequential( (0): Bottleneck( (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) (downsample): Sequential( (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): Bottleneck( (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) ) (2): Bottleneck( (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) ) (3): Bottleneck( (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) ) ) ) model_object.load_state_dict(torch.load(‘params.pth’)) 参数的输出 # 仅保存和加载模型参数 torch.save(model_object.state_dict(), 'params.pth') dic = torch.load('params.pth') model_object.load_state_dict(dic) print(dic) >>>【结果】 OrderedDict([('conv1.weight', tensor([[[[ 1.3335e-02, 1.4664e-02, -1.5351e-02, ..., -4.0896e-02, -4.3034e-02, -7.0755e-02], [ 4.1205e-03, 5.8477e-03, 1.4948e-02, ..., 2.2060e-03, -2.0912e-02, -3.8517e-02], [ 2.2331e-02, 2.3595e-02, 1.6120e-02, ..., 1.0281e-01, 6.2641e-02, 5.1977e-02], ..., ('bn1.weight', tensor([ 2.3888e-01, 2.9136e-01, 3.1615e-01, 2.7122e-01, 2.1731e-01, 3.0903e-01, 2.2937e-01, 2.3086e-01, 2.1129e-01, 2.8054e-01, 1.9923e-01, 3.1894e-01, 1.7955e-01, 1.1246e-08, 1.9704e-01, 2.0996e-01, 2.4317e-01, 2.1697e-01, 1.9415e-01, 3.1569e-01, 1.9648e-01, 2.3214e-01, 2.1962e-01, 2.1633e-01, 2.4357e-01, 2.9683e-01, 2.3852e-01, 2.1162e-01, 1.4492e-01, 2.9388e-01, 2.2911e-01, 9.2716e-02, 4.3334e-01, 2.0782e-01, 2.7990e-01, 3.5804e-01, 2.9315e-01, 2.5306e-01, 2.4210e-01, 2.1755e-01, 3.8645e-01, 2.1003e-01, 3.6805e-01, 3.3724e-01, 5.0826e-01, 1.9341e-01, 2.3914e-01, 2.6652e-01, 3.9020e-01, 1.9840e-01, 2.1694e-01, 2.6666e-01, 4.9806e-01, 2.3553e-01, 2.1349e-01, 2.5951e-01, 2.3547e-01, 1.7579e-01, 4.5354e-01, 1.7102e-01, 2.4903e-01, 2.5148e-01, 3.8020e-01, 1.9665e-01])), ('bn1.bias', tensor([ 2.2484e-01, 6.0617e-01, 1.2483e-02, 1.3270e-01, 1.8030e-01, 1.4739e-01, 1.7430e-01, 1.9023e-01, 2.3226e-01, 2.0082e-01, 1.2834e-01, -2.1285e-01, 1.5065e-01, -3.9217e-08, 2.4985e-01, 2.0454e-01, 5.4934e-01, 2.1021e-01, 2.2505e-01, 4.6484e-01, 2.3888e-01, 2.0442e-01, 2.1546e-01, 6.6194e-01, 2.2755e-01, 6.6069e-01, 2.0587e-01, 1.9292e-01, 1.1195e-01, 3.3785e-01, 1.2393e-01, 4.1079e-02, 7.7150e-01, 2.6964e-01, 3.3347e-01, 5.7908e-01, 1.5026e-01, 1.7534e-01, 1.9429e-01, 1.7248e-01, 8.0577e-01, 2.3693e-01, -4.3369e-01, 8.4813e-01, -3.7857e-01, 2.4787e-01, 1.8101e-01, 3.2949e-01, -2.8598e-01, 2.2717e-01, 2.6168e-01, 5.7609e-02, -5.0320e-01, 1.5704e-01, 1.7890e-01, 2.8114e-01, 4.2167e-01, -9.7650e-02, -3.1231e-01, -2.5637e-02, 8.8566e-02, 1.8052e-01, 8.3045e-01, 2.5015e-01])), ('bn1.running_mean', tensor([ 2.8781e-02, 1.0830e-01, 2.6812e-01, -4.7955e-02, -2.7350e-02, -1.2350e-02, -2.8534e-02, 3.8390e-02, 8.6643e-03, 1.1076e-01, -1.6231e-02, -7.1499e-01, 5.7644e-02, -5.1895e-07, -1.9860e-02, 6.5988e-03, 4.9869e-01, -3.4726e-02, -2.2373e-02, -6.4198e-01, 3.3326e-02, 6.5970e-02, 3.1869e-02, 3.1863e-01, 3.7692e-02, 4.9075e-01, 3.0402e-02, -6.5330e-02, -2.4589e-02, 4.3018e-01, -6.3207e-02, 3.6987e-02, -7.9438e-01, 3.7037e-02, 8.1242e-01, -8.8931e-01, -3.4412e-02, -1.6578e-01, -1.8018e-02, -2.7667e-02, -1.3835e+00, 7.8008e-02, -7.0342e-01, 3.4551e-01, 5.7252e-01, 4.5663e-02, 5.2766e-02, 2.8974e-01, -3.4401e-01, 1.6897e-02, 9.7269e-02, -2.1634e-02, 7.9793e-01, 1.7612e-02, -3.2805e-03, -1.7782e-01, -1.4005e-01, 4.1215e-02, 7.2888e-01, -2.2417e-01, 1.9287e-03, 8.7772e-02, 1.3144e+00, -3.8825e-02])), ('bn1.running_var', tensor([ 5.0796e-01, 1.4441e+00, 3.3001e+00, 3.3098e+00, 1.3029e-01, 3.3023e+00, 1.2143e-01, 2.5986e-01, 8.9925e-02, 2.9480e+00, 1.3752e-01, 2.1341e+00, 6.9679e-02, 2.7234e-12, 2.4457e-02, 6.9063e-02, 1.1395e+00, 8.0611e-02, 2.1984e-02, 2.6701e+00, 5.6415e-02, 2.1792e-01, 1.0816e-01, 9.8851e-01, 3.0843e-01, 2.9959e+00, 5.4037e-02, 1.7887e-01, 2.8518e-02, 1.8343e+00, 7.0009e-01, 2.9475e-02, 1.1048e+01, 7.5987e-03, 2.6686e+00, 5.0308e+00, 2.8717e+00, 1.7434e+00, 3.8133e-01, 1.3055e-01, 8.6697e+00, 3.9596e-02, 2.3990e+00, 3.7014e+00, 6.9698e+00, 1.2682e-01, 1.4923e-01, 1.5581e+00, 1.1554e+00, 2.0051e-02, 1.3014e-01, 9.9781e-01, 3.6349e+00, 2.4568e-01, 1.2094e-01, 7.6329e-01, 7.9295e-01, 1.5916e-01, 3.8380e+00, 3.2014e-01, 3.4269e-01, 3.3512e-01, 8.0546e+00, 2.4255e-02])), ('layer1.0.conv1.weight', tensor([[[[ 3.5144e-03]], [[ 3.9855e-02]], [[-2.4795e-02]], ..., ('layer1.0.bn1.weight', tensor([ 2.1341e-01, 1.8848e-01, 1.4136e-01, 1.5273e-01, 1.3220e-01, 1.8735e-01, 1.4475e-01, 4.5110e-08, 1.5993e-01, 1.4946e-01, 2.3499e-01, 1.8315e-01, 1.8516e-01, 1.4933e-01, 1.3090e-01, 1.0634e-01, 3.7487e-01, 1.2644e-01, 3.1895e-01, 2.7160e-01, 2.5810e-01, 2.9458e-01, 1.8395e-01, 2.1088e-08, 3.3313e-01, 2.0461e-01, 3.0399e-01, 1.1805e-08, 1.4977e-01, 1.5719e-01, 1.4011e-01, 1.4900e-01, 1.2438e-01, 1.8786e-01, 1.4257e-01, 3.4828e-01, 1.5038e-01, 3.0034e-01, 2.5925e-01, 1.0711e-01, 2.6875e-01, 1.3552e-01, 1.1822e-01, 1.1189e-01, 2.8736e-01, 3.2637e-01, 1.4781e-01, 2.3105e-01, 3.3638e-01, 2.8808e-01, 1.2319e-01, 3.0763e-01, 1.1846e-01, 1.3137e-01, 2.0671e-01, 1.5787e-01, 2.6574e-08, 2.0467e-01, 2.8797e-08, 1.8284e-01, 3.0180e-01, 1.7401e-01, 2.8438e-01, 2.3715e-01])), ('layer1.0.bn1.bias', tensor([ 4.3266e-01, 4.6854e-02, -8.0134e-02, 7.3302e-02, 2.7970e-01, -7.8047e-03, 9.4087e-02, -1.0086e-07, -1.4034e-01, -5.1599e-02, 4.4470e-02, 2.1814e-01, 4.0718e-02, 1.1979e-01, 1.4432e-01, 1.3672e-01, -1.1168e-01, 1.4774e-01, -1.2879e-01, -5.3147e-02, -3.3920e-02, -2.0600e-02, 6.2783e-02, -6.5736e-08, -7.1213e-02, 6.9510e-02, -1.3264e-01, -6.4411e-08, -2.8908e-02, 9.4164e-02, 2.4790e-01, -8.2850e-02, -2.8872e-02, -1.7086e-01, 9.9522e-02, -1.1357e-01, 1.9770e-01, 1.4800e-02, -7.0896e-02, 1.0722e-01, 1.2536e-02, -3.6633e-02, 1.4959e-01, 1.0533e-01, 2.0933e-02, -1.0502e-01, -4.8848e-02, 4.9007e-01, -1.4755e-01, -1.0900e-01, 1.9815e-02, -7.0964e-02, -4.6543e-02, 1.0874e-01, -2.7878e-01, 4.4500e-03, -7.7156e-08, 7.5060e-02, -8.4474e-08, 2.2533e-01, -7.1593e-02, -1.5823e-01, -3.4459e-02, 5.2894e-01])), ('layer1.0.bn1.running_mean', tensor([-6.0619e-01, -3.5467e-01, 2.4651e-01, -2.5210e-01, -7.6892e-02, -3.3654e-01, -1.0111e-01, -1.7881e-08, 2.1631e-01, -2.8016e-01, -3.1948e-01, 1.1134e+00, -1.1791e-01, -2.0125e-01, -3.2957e-01, -2.6431e-02, -3.4833e-01, 7.1402e-01, -2.7727e-01, -2.7576e-01, -1.7791e-01, -1.1054e-01, -1.5952e-01, -5.6052e-45, -3.6867e-01, -1.7413e-01, -2.6344e-01, 4.3125e-09, -2.3616e-01, -3.0546e-01, -1.8908e-02, 2.2109e-01, 1.1146e-02, -1.4291e-01, -3.0156e-01, -4.4344e-01, -2.2829e-01, -2.0861e-01, -2.2197e-01, 3.1603e-01, -1.1507e-01, -1.3784e-01, -2.9271e-01, -4.8246e-01, -1.5741e-01, -2.6682e-01, -3.8136e-01, -3.1360e-01, -1.9755e-01, -4.1116e-01, -2.8717e-02, -3.0186e-01, 8.8766e-02, -3.3887e-01, -5.9848e-02, -6.4817e-01, 1.2924e-09, -2.2738e-01, -5.6052e-45, 1.0252e+00, -9.3871e-02, -1.4969e-02, -4.0218e-01, -1.3630e-01])), 指定模块load参数

方法1: 在日常实验中,你肯定遇到过这种情况——在你保存好你的模型参数后,你想整理下你模型里的模块的定义,删除一些冗余模块(上次训练没有用到的模块)。于是你删除了冗余模块 self.extra。 但是,当你再次load模型参数时,会报错缺少这一模块。为了能够再不重新训练的情况下顺利把之前的参数load进去,我们可以采取下面的做法。

import torch import torch.nn as nn class modelfunc(nn.Module): # 之前定义好的模型 def __init__(self, class_num): super(modelfunc, self).__init__() self.fc1 = nn.Linear(3,5) self.extra = nn.Linear(2,2) # 冗余模块 self.fc2 = nn.Linear(5,class_num) def forward(self,x): x = self.fc1(x) x = self.fc2(x) return x # 由于pytorch没有像keras那样有保存模型结构的API,因此,每次load之前必须找到模型的结构。 model_object = modelfunc(3) # 导入模型结构 # 仅保存和加载模型参数 # torch.save(model_object.state_dict(), 'params.pth')

保存好模型后,我们删除掉模型中的冗余模块 self.extra。为了在缺少冗余模块的情况下还能够顺利load之前的参数,我们采用下面的做法。

import torch import torch.nn as nn class modelfunc(nn.Module): # 之前定义好的模型 def __init__(self, class_num): super(modelfunc, self).__init__() self.fc1 = nn.Linear(3,5) # self.extra = nn.Linear(2,2) # 注释掉 冗余模块 self.fc2 = nn.Linear(5,class_num) def forward(self,x): x = self.fc1(x) x = self.fc2(x) return x # 由于pytorch没有像keras那样有保存模型结构的API,因此,每次load之前必须找到模型的结构。 model_object = modelfunc(3) # 导入模型结构 from collections import OrderedDict dict = torch.load('params.pth') new_dict = OrderedDict() for key in dict: if key in model_object.state_dict(): new_dict[key] = dict[key] model_object.load_state_dict(new_dict)

方法2:

model_object = modelfunc(3) # 导入模型结构 pretrain_dict =torch.load('params.pth') model_dict = model_object.state_dict() pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()} model_dict.update(pretrain_dict) model_object.load_state_dict(model_dict)


【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

      专题文章
        CopyRight 2018-2019 实验室设备网 版权所有